Dirichlet Process Mixture Models

Clustering
Unsupervised Learning
Non-parametric
A non-parametric Bayesian clustering method that automatically determines the number of clusters.

General Principles

To discover group structures or clusters in data without pre-specifying the number of groups, we can use a Dirichlet Process Mixture Model (DPMM) (Gershman and Blei 2012). This is an unsupervised clustering method. Essentially, the model assumes the data is generated from a collection of different Gaussian distributions, and it simultaneously tries to figure out:

  1. How many clusters (K) exist: Unlike K-Means, the DPMM infers the most probable number of clusters directly from the data.
  2. The properties of each cluster: For each inferred cluster, it estimates its location and its spread.
  3. The assignment of each data point: It determines the probability of each data point belonging to each cluster.

Considerations

Caution
  • A DPMM is a Bayesian model that considers uncertainty in all its parameters. The core idea is to use a Dirichlet Process prior that allows for a potentially infinite number of clusters. In practice, we use a finite approximation where we cap the maximum number of clusters at K and use the Stick-Breaking Process.

  • The key parameters and their priors are:

    • Concentration \alpha: This single parameter controls the tendency to create new clusters. A low \alpha favors fewer, larger clusters, while a high \alpha allows for many smaller clusters. We typically place a Gamma prior on \alpha to learn its value from the data.
    • Cluster Weights w: Generated via the Stick-Breaking process from \alpha. These are the probabilities of drawing a data point from any given cluster.

    • Cluster Parameters (\mu, \Sigma): Each potential cluster has a mean \mu and a covariance matrix \Sigma. If the data has multiple dimensions, we use a multivariate normal distribution (see Chapter 14). However, if the data is one-dimensional, we use a univariate normal distribution.

  • The model is often implemented in its marginalized form. Instead of explicitly assigning each data point to a cluster, we integrate out this choice. This creates a smoother probability surface for the inference algorithm to explore, leading to much more efficient computation.
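The Stick-Breaking construction itself is easy to verify outside the model. Below is a minimal NumPy sketch (independent of the BI API) of how Beta draws become mixture weights, and how the size of \alpha controls how quickly the stick is used up:

```python
import numpy as np

def stick_breaking_weights(betas):
    """Turn K-1 Beta draws into K mixture weights.

    beta_k is the fraction broken off the remaining stick, so
    w_k = beta_k * prod_{j<k} (1 - beta_j); the last weight is the leftover.
    """
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
    return np.concatenate([betas, [1.0]]) * remaining

rng = np.random.default_rng(0)
K = 10

# Beta(1, alpha): small alpha -> long early sticks -> few large clusters;
# large alpha -> short sticks -> mass spread over many clusters.
w_low = stick_breaking_weights(rng.beta(1.0, 0.5, size=K - 1))   # low alpha
w_high = stick_breaking_weights(rng.beta(1.0, 10.0, size=K - 1)) # high alpha

print("low alpha :", w_low.round(3))
print("high alpha:", w_high.round(3))
```

Both weight vectors sum to one by construction; with a low \alpha most of the mass typically sits in the first few components.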

Example

Below is an example of a DPMM implemented in BI. The goal is to cluster a synthetic dataset into its underlying groups. The code first generates data with 8 distinct centers and then applies the DPMM to recover these clusters.

Code
from BI import bi, jnp 
from sklearn.datasets import make_blobs
import numpyro

m = bi(rand_seed = False)

# Generate synthetic data
data, true_labels = make_blobs(
    n_samples=500, centers=8, cluster_std=0.8,
    center_box=(-10,10), random_state=101
)
data_mean = jnp.mean(data, axis=0)
data_std = jnp.std(data, axis=0)*2

# The model
def dpmm(data, K, data_mean, data_std):
    N, D = data.shape  # N samples, D features

    # 1) stick-breaking weights
    alpha = m.dist.gamma(1.0, 10.0, name='alpha')

    with m.dist.plate("beta_plate", K - 1):
        beta = m.dist.beta(1, alpha, name='beta')

    w = numpyro.deterministic("w", m.models.dpmm.mix_weights(beta))

    # 2) component parameters
    with m.dist.plate("components", K):
        mu = m.dist.multivariate_normal(loc=data_mean, covariance_matrix=data_std*jnp.eye(D), name='mu')  # shape (K, D)
        sigma = m.dist.log_normal(0.0, 1.0, shape=(D,), event=1, name='sigma')  # shape (K, D)
        Lcorr = m.dist.lkj_cholesky(dimension=D, concentration=1.0, name='Lcorr')  # shape (K, D, D)

        scale_tril = sigma[..., None] * Lcorr  # shape (K, D, D)

    # 3) Latent cluster assignments for each data point
    m.dist.mixture_same_family(
        mixing_distribution=m.dist.categorical(probs=w, create_obj=True),
        component_distribution=m.dist.multivariate_normal(
            loc=mu, 
            scale_tril=scale_tril, 
            create_obj=True
        ),
        obs=data
    )

m.data_on_model = dict(data=data, K=10, data_mean=data_mean, data_std=data_std)
m.fit(dpmm)  # Infer model parameters through MCMC sampling
m.plot(X=data, sampler=m.sampler)  # Prebuilt plot function for GMM
sample: 100%|██████████| 1000/1000 [00:24<00:00, 41.39it/s, 63 steps of size 4.60e-02. acc. prob=0.87]
Model found 8 clusters.

using BayesianInference
using PythonCall
numpyro = pyimport("numpyro")

m = importBI(rand_seed = false)

# 1. Generate Data
sk_datasets = pyimport("sklearn.datasets")
output = sk_datasets.make_blobs(n_samples=500, centers=8, cluster_std=0.8, center_box=(-10, 10), random_state=101)
data = output[0]
data_mean = jnp.mean(data, axis=0)
data_std = jnp.std(data, axis=0) * 2
m.data_on_model = pydict(data=data, K=10, data_mean = data_mean, data_std = data_std)


@BI function dpmm(data, K, data_mean , data_std)
    N, D = data.shape 

    alpha = m.dist.gamma(1.0, 10.0, name="alpha")

    beta = pywith(m.dist.plate("beta_plate", K - 1)) do _
        m.dist.beta(1, alpha, name = "beta")
    end

    w = numpyro.deterministic("w", m.models.dpmm.mix_weights(beta))

    mu, scale_tril = pywith(m.dist.plate("components", K)) do _
        mu_val = m.dist.multivariate_normal(
            loc=data_mean, 
            covariance_matrix=data_std * jnp.eye(D),
            name="mu"
        )
        
        sigma = m.dist.log_normal(0.0, 1.0, shape=(D,), event=1, name="sigma")
        Lcorr = m.dist.lkj_cholesky(dimension=D, concentration=1.0, name="Lcorr")
        scale_tril_inner = jnp.expand_dims(sigma, -1) * Lcorr
        (mu_val, scale_tril_inner)
    end
    
    m.dist.mixture_same_family(
        mixing_distribution=m.dist.categorical(probs=w, create_obj=true),
        component_distribution=m.dist.multivariate_normal(
            loc=mu, 
            scale_tril=scale_tril, 
            create_obj=true
        ),
        obs=data
    )
end

# 2. Run

m.fit(dpmm) 

@pyplot m.models.dpmm.plot_dpmm(m.data_on_model["data"], m.sampler)
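After fitting, the inferred number of clusters can be read off by thresholding the posterior means of the weights w. The sketch below uses synthetic Dirichlet draws as a stand-in for real posterior samples; the 0.02 threshold and the array shapes are illustrative assumptions, not part of the BI API:

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in for posterior draws of the K = 10 mixture weights:
# eight components carry real mass, two are nearly switched off.
w_samples = rng.dirichlet(np.array([5.0] * 8 + [0.05] * 2), size=2000)

w_mean = w_samples.mean(axis=0)      # posterior mean weight per component
active = int((w_mean > 0.02).sum())  # components with non-negligible mass
print(f"Model found {active} clusters.")
```

With these synthetic draws the eight heavy components each average about 0.12 of the mass, while the two near-zero components stay far below the threshold.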

Mathematical Details

The process involves two key submodels. The first aims to identify the location and scale of K potential clusters. The second aims to identify which cluster is most likely to have generated a given data point.

\begin{aligned} \begin{pmatrix} Y_{i,1} \\ \vdots \\ Y_{i,D} \end{pmatrix} &\sim \text{MVN}\!\left( \begin{pmatrix} \mu_{z_i,1} \\ \vdots \\ \mu_{z_i,D} \end{pmatrix}, \, \Sigma_{z_i} \right) \\ \\ \begin{pmatrix} \mu_{k,1} \\ \vdots \\ \mu_{k,D} \end{pmatrix} &\sim \text{MVN}\!\left( \begin{pmatrix} A_{1} \\ \vdots \\ A_{D} \end{pmatrix}, \, B \right) \\ \\ \Sigma_k &= \text{Diag}(\sigma_k) \Omega_k \text{Diag}(\sigma_k) \\ \\ \sigma_{[k,d]} &\sim \text{HalfCauchy}(1) \\ \\ \Omega_k &\sim \text{LKJ}(2) \\ \\ z_{i} &\sim \text{Categorical}(\pi) \\ \\ \pi_{k}(\beta_{1:K}) &= \beta_k \prod_{j<k} (1-\beta_j) \\ \\ \beta_k &\sim \text{Beta}(1, \alpha) \\ \\ \alpha &\sim \text{Gamma}(1, 10) \\ \end{aligned}

Where:

  • \begin{pmatrix} Y_{[i,1]} \\ \vdots \\ Y_{[i,D]} \end{pmatrix} is the i-th observation of a D-dimensional data array.

  • \begin{pmatrix}\mu_{[k,1]} \\ \vdots \\ \mu_{[k,D]}\end{pmatrix} is the k-th parameter vector of dimension D.

  • \begin{pmatrix} A_{1} \\ \vdots \\ A_{D} \end{pmatrix} is the prior mean vector for the cluster means, derived from the mean of the raw data.

  • B is the prior covariance of the cluster means, set up as a diagonal matrix (in the example code, scaled by twice the standard deviation of the data).

  • \Sigma_k is the DxD covariance matrix of the k-th cluster (it is composed from \sigma_k and \Omega_k).

  • \text{Diag}(\sigma_k) is a diagonal matrix whose diagonal entries are the standard deviations: \text{Diag}(\sigma_k) = \begin{pmatrix} \sigma_{[k,1]} & 0 & \cdots & 0 \\ 0 & \sigma_{[k,2]} & & \vdots \\ \vdots & & \ddots & 0 \\ 0 & \cdots & 0 & \sigma_{[k,D]} \end{pmatrix}.

  • \sigma_{k} is a D-vector of standard deviations for the k-th cluster, where each element d has a Half-Cauchy prior.

  • \Omega_k is a correlation matrix for the k-th cluster.

  • z_i is a latent variable that maps observation i to cluster k.

  • \pi is a vector of K cluster weights, some of which may be close to zero if the predicted number of clusters is less than the maximum number of clusters.

  • \beta_k: The Beta-distributed random variables (K - 1 of them in the truncated implementation) used in the stick-breaking process to construct the mixture weights.

  • \alpha: The concentration parameter, controlling the effective number of clusters.
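The decomposition \Sigma_k = \text{Diag}(\sigma_k) \Omega_k \text{Diag}(\sigma_k) is exactly what the example code builds when it multiplies sigma into the Cholesky factor of the correlation matrix (scale_tril = sigma[..., None] * Lcorr). A NumPy check with a hand-picked 2-D correlation matrix:

```python
import numpy as np

sigma = np.array([2.0, 0.5])     # per-dimension standard deviations
Omega = np.array([[1.0, 0.3],    # a valid 2x2 correlation matrix
                  [0.3, 1.0]])

# Full covariance: Diag(sigma) @ Omega @ Diag(sigma)
Sigma = np.diag(sigma) @ Omega @ np.diag(sigma)

# Cholesky route used in the model: scale rows of chol(Omega) by sigma
scale_tril = sigma[:, None] * np.linalg.cholesky(Omega)

print(np.allclose(scale_tril @ scale_tril.T, Sigma))  # True
```

Working with scale_tril instead of the full \Sigma_k lets the sampler parameterize scale and correlation separately while guaranteeing a valid (positive-definite) covariance.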

Notes

Note
  • The primary advantage of the DPMM is the automatic inference of the number of clusters. The posterior distribution of the weights w reveals which components are "active", giving a probabilistic estimate of K.

  • The prior on \alpha strongly influences the predicted number of clusters. Below are examples of this relationship:

Impact of Gamma Prior Hyperparameters on Cluster Counts

  Shape   Rate   Behavior
  1       15     Forces very few clusters
  5       1      Encourages many small clusters
  10      2      Same mean as Gamma(5, 1), less variance
  2       0.5    Moderately many clusters
  15      1      Explosive prior cluster count
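These behaviors follow from the Gamma(shape, rate) moments, mean = shape/rate and variance = shape/rate^2. A quick check of the table rows (plain arithmetic, not BI code):

```python
def gamma_moments(shape, rate):
    """Mean and variance of a Gamma(shape, rate) distribution."""
    return shape / rate, shape / rate ** 2

for a, b in [(1, 15), (5, 1), (10, 2), (2, 0.5), (15, 1)]:
    mean, var = gamma_moments(a, b)
    print(f"Gamma({a}, {b}): mean = {mean:.3f}, variance = {var:.3f}")
```

For instance, Gamma(5, 1) and Gamma(10, 2) share the mean 5, but the latter has variance 2.5 instead of 5, which is the "same mean, less variance" row.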

Reference(s)

https://en.wikipedia.org/wiki/Dirichlet_process
https://pyro.ai/examples/dirichlet_process_mixture.html

References

Gershman, Samuel J, and David M Blei. 2012. β€œA Tutorial on Bayesian Nonparametric Models.” Journal of Mathematical Psychology 56 (1): 1–12.